/*************************************************************************** * Copyright 2009-2012 by Christian Ihle * * kontakt@usikkert.net * * * * This file is part of KouInject. * * * * KouInject is free software; you can redistribute it and/or modify * * it under the terms of the GNU Lesser General Public License as * * published by the Free Software Foundation, either version 3 of * * the License, or (at your option) any later version. * * * * KouInject is distributed in the hope that it will be useful, * * but WITHOUT ANY WARRANTY; without even the implied warranty of * * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * * Lesser General Public License for more details. * * * * You should have received a copy of the GNU Lesser General Public * * License along with KouInject. * * If not, see <http://www.gnu.org/licenses/>. * ***************************************************************************/ package net.usikkert.kouinject; import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.net.URL; import java.security.AccessController; import java.security.PrivilegedAction; import java.util.Enumeration; import java.util.HashSet; import java.util.Set; import java.util.jar.JarEntry; import java.util.jar.JarInputStream; import java.util.logging.Logger; import net.usikkert.kouinject.util.ReflectionUtils; import org.apache.commons.lang.StringUtils; import org.apache.commons.lang.Validate; /** * Finds classes by scanning the classpath. Classes are searched for in the file system and in * jar-files. * * TODO: refactor into smaller pieces. * * @author Christian Ihle */ public class ClassPathScanner implements ClassLocator { private static final Logger LOG = Logger.getLogger(ClassPathScanner.class.getName()); private final ReflectionUtils reflectionUtils = new ReflectionUtils(); private final ClassLoader classLoader; /** * Creates a new classpath scanner using the threads context classloader, or the classloader for this class. */ public ClassPathScanner() { this.classLoader = getClassLoader(); } /** * Creates a new classpath scanner using the specified classloader. * * @param classLoader The classloader to use to scan for classes, and load the classes. */ public ClassPathScanner(final ClassLoader classLoader) { Validate.notNull(classLoader, "Class loader can not be null"); this.classLoader = classLoader; } /** * {@inheritDoc} */ @Override public Set<Class<?>> findClasses(final String... basePackages) { Validate.notNull(basePackages, "Base packages can not be null"); final long start = System.currentTimeMillis(); final Set<Class<?>> classes = findClassesFromSetOfBasePackages(basePackages); final long stop = System.currentTimeMillis(); LOG.fine("Time spent scanning classpath: " + (stop - start) + " ms"); LOG.fine("Classes found: " + classes.size()); return classes; } private Set<Class<?>> findClassesFromSetOfBasePackages(final String... basePackages) { final Set<Class<?>> classes = new HashSet<Class<?>>(); final Set<String> basePackageSet = convertBasePackagesToSet(basePackages); for (final String basePackage : basePackageSet) { classes.addAll(findClassesFromBasePackage(basePackage)); } return classes; } private Set<String> convertBasePackagesToSet(final String... basePackages) { Validate.notNull(basePackages, "Base packages can not be null"); Validate.isTrue(basePackages.length > 0, "Must have at least one base package"); final Set<String> basePackageSet = new HashSet<String>(); for (final String basePackage : basePackages) { Validate.isTrue(StringUtils.isNotBlank(basePackage), "Base package can not be empty"); basePackageSet.add(basePackage); } return basePackageSet; } private Set<Class<?>> findClassesFromBasePackage(final String basePackage) { final Set<Class<?>> classes = new HashSet<Class<?>>(); final String path = basePackage.replace('.', '/'); try { final Enumeration<URL> resources = classLoader.getResources(path); if (resources != null) { AccessController.doPrivileged(new PrivilegedAction<Object>() { @Override public Object run() { // hasMoreElements requires java.io.FilePermission "read" to find anything while (resources.hasMoreElements()) { classes.addAll(getClassesFromResource(basePackage, path, resources)); } return null; } }); } } catch (final IOException e) { throw new RuntimeException(e); } return classes; } private Set<Class<?>> getClassesFromResource(final String basePackage, final String path, final Enumeration<URL> resources) { final Set<Class<?>> classes = new HashSet<Class<?>>(); final String filePath = getFilePath(resources.nextElement()); if (filePath != null) { if (isJarFilePath(filePath)) { final String jarPath = getJarPath(filePath); classes.addAll(getFromJARFile(jarPath, path)); } else { classes.addAll(getFromDirectory(new File(filePath), basePackage)); } } return classes; } private Set<Class<?>> getFromDirectory(final File directory, final String packageName) { final Set<Class<?>> classes = new HashSet<Class<?>>(); if (directory.exists()) { final File[] files = directory.listFiles(); for (final File file : files) { if (file.isDirectory()) { classes.addAll(getFromDirectory(file, packageName + "." + file.getName())); } else if (isClass(file.getName())) { final String className = packageName + '.' + stripFilenameExtension(file.getName()); final Class<?> clazz = loadClass(className); addClass(clazz, classes); } } } return classes; } private Set<Class<?>> getFromJARFile(final String jar, final String packageName) { final Set<Class<?>> classes = new HashSet<Class<?>>(); JarInputStream jarFile = null; try { jarFile = new JarInputStream(new FileInputStream(jar)); JarEntry jarEntry; do { jarEntry = jarFile.getNextJarEntry(); if (jarEntry != null) { final String fileName = jarEntry.getName(); if (isClass(fileName)) { final String className = stripFilenameExtension(fileName); if (className.startsWith(packageName)) { final Class<?> clazz = loadClass(className.replace('/', '.')); addClass(clazz, classes); } } } } while (jarEntry != null); } catch (final IOException e) { throw new RuntimeException(e); } finally { if (jarFile != null) { try { jarFile.close(); } catch (final IOException e) { throw new RuntimeException(e); } } } return classes; } /** * Gets the best possible classloader for scanning after classes. Usually it's the current * thread's context classloader, but if that's not available then the classloader for this class * is used instead. * * @return A usable classloader. */ private ClassLoader getClassLoader() { final ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader(); if (contextClassLoader != null) { return contextClassLoader; } else { return getClass().getClassLoader(); } } private String getFilePath(final URL url) { final String filePath = url.getFile(); if (filePath != null) { return fixWindowsSpace(filePath); } return null; } private boolean isJarFilePath(final String filePath) { return (filePath.indexOf("!") > 0) && (filePath.indexOf(".jar") > 0); } private String fixWindowsSpace(final String filePath) { if (filePath.indexOf("%20") > 0) { return filePath.replaceAll("%20", " "); } return filePath; } private String getJarPath(final String filePath) { final String jarPath = filePath.substring(0, filePath.indexOf("!")).substring(filePath.indexOf(":") + 1); return fixWindowsJarPath(jarPath); } private String fixWindowsJarPath(final String jarPath) { if (jarPath.indexOf(":") >= 0) { return jarPath.substring(1); } return jarPath; } private static String stripFilenameExtension(final String filename) { if (filename == null) { return null; } final int dotIndex = filename.lastIndexOf("."); if (dotIndex == -1) { return filename; } return filename.substring(0, dotIndex); } private boolean isClass(final String fileName) { return fileName.endsWith(".class"); } private void addClass(final Class<?> clazz, final Set<Class<?>> classes) { if (reflectionUtils.isNormalClass(clazz)) { classes.add(clazz); } } private Class<?> loadClass(final String className) { try { return Class.forName(className, false, classLoader); } catch (final ClassNotFoundException e) { throw new RuntimeException(e); } } }